from src.datasets.dataset_creator.dsl_0 import *
from src.datasets.dataset_creator.constants import *
from src.datasets.dataset_creator.arc_types import *
import numpy as np
import heapq
from src.datasets.dataset_creator.input_generation.v1_agent import *
from collections import deque


def snake_move_program(inputs, key) -> Grid:
    """
    Simulates the snake moving towards and eating the apple.
    Ensures there is always exactly one tail and the head remains yellow when it eats the apple.

    The snake is first read from the grid as an ordered chain
    head -> body... -> tail, then advanced one cell per step until the head
    lands on the apple. The snake's length never changes: each step prepends
    the new head and drops the old tail.

    Args:
        inputs: A pair ``(grid, _)``; only the grid is used. The grid is
            expected to contain an ORANGE head (4), a PURPLE tail (5), GREEN
            body cells (3) and a RED apple (2) on a BLACK (0) background. If a
            color appears more than once, the last occurrence in row-major
            scan order wins.
        key: Unused here.

    Returns:
        Grid: The final grid (as a numpy array) after the snake has reached the apple

    Raises:
        ValueError: If the head, tail or apple is missing, or if the GREEN
            cells cannot be ordered into one 4-connected chain from head to tail.
    """
    grid, _ = inputs

    grid = tuple(tuple(row) for row in grid)

    # NOTE(review): assumes a square grid -- grid_size = len(grid) bounds both
    # the row and the column index below.
    grid_size = len(grid)
    BLACK, RED, GREEN, ORANGE, PURPLE = 0, 2, 3, 4, 5  # Color indices

    def find_snake_and_apple():
        # Returns (snake, apple), where snake is [head, body..., tail].
        head, tail, body_parts, apple = None, None, set(), None
        for i in range(grid_size):
            for j in range(grid_size):
                if grid[i][j] == ORANGE:
                    head = (i, j)
                elif grid[i][j] == PURPLE:
                    tail = (i, j)
                elif grid[i][j] == GREEN:
                    body_parts.add((i, j))
                elif grid[i][j] == RED:
                    apple = (i, j)

        if not head or not tail or not apple:
            raise ValueError(f"Invalid grid: missing head ({head}), tail ({tail}), or apple ({apple})")

        def get_neighbors(pos):
            # In-bounds 4-neighbours (up, down, left, right).
            return [
                (pos[0] + dx, pos[1] + dy)
                for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1)]
                if 0 <= pos[0] + dx < grid_size and 0 <= pos[1] + dy < grid_size
            ]

        def reconstruct_snake(current, path, remaining_body):
            # Backtracking depth-first search for an ordering that visits every
            # GREEN cell exactly once and ends on the tail. The tail may only be
            # entered once no body cells remain. Recursion depth grows with the
            # snake's length.
            if current == tail and not remaining_body:
                return path + [tail]

            neighbors = get_neighbors(current)
            valid_neighbors = [
                n for n in neighbors if n in remaining_body or (n == tail and not remaining_body)
            ]

            for neighbor in valid_neighbors:
                new_path = path + [current]
                new_remaining = remaining_body - {neighbor} if neighbor != tail else remaining_body

                result = reconstruct_snake(neighbor, new_path, new_remaining)
                if result:
                    return result

            return None

        snake = reconstruct_snake(head, [], body_parts)

        if not snake:
            raise ValueError(
                f"Unable to reconstruct a valid snake. Head: {head}, Tail: {tail}, Body parts: {body_parts}"
            )

        return snake, apple

    def move_snake(grid, snake, apple):
        # Advances the snake one step toward the apple; returns (grid, new_snake).
        head = snake[0]
        # Determine movement direction
        # NOTE(review): dx and dy are chosen independently, so when the apple
        # differs in both row and column the head moves diagonally and the
        # chain is no longer 4-connected -- confirm this is intended.
        dx = 1 if apple[0] > head[0] else -1 if apple[0] < head[0] else 0
        dy = 1 if apple[1] > head[1] else -1 if apple[1] < head[1] else 0

        # Calculate new head position
        # NOTE(review): no check that new_head avoids the snake's own body.
        new_head = (head[0] + dx, head[1] + dy)

        # Move snake
        new_snake = [new_head] + snake[:-1]

        # Update grid
        # `fill` comes from the DSL; presumably it returns a new grid with the
        # listed cells set to the color (the result is reassigned each time).
        # The order matters: the tail is cleared before the head is painted,
        # and the new tail is painted last so it overrides the GREEN body
        # color for a two-cell snake.
        grid = fill(grid, BLACK, (snake[-1],))  # Remove old tail
        grid = fill(grid, ORANGE, (new_head,))  # New head
        grid = fill(grid, GREEN, (snake[0],))  # Old head becomes body
        if len(new_snake) > 1:
            grid = fill(grid, PURPLE, (new_snake[-1],))  # New tail

        return grid, new_snake

    try:
        snake, apple = find_snake_and_apple()
    except ValueError as e:
        # Dump the offending grid for debugging, then propagate the error.
        print(f"Error in find_snake_and_apple: {e}")
        print("Grid state:")
        for row in grid:
            print(" ".join(str(cell) for cell in row))
        raise

    # Each step moves the head one cell closer in both axes, so this terminates.
    while snake[0] != apple:
        grid, snake = move_snake(grid, snake, apple)

    # Once the apple is reached
    # These fills appear to restate colors already set by the last move_snake
    # call; they act as a safety net for the final head/body/tail colors.
    grid = fill(grid, ORANGE, (apple,))  # Head color (yellow) on apple position
    if len(snake) > 1:
        grid = fill(grid, GREEN, (snake[1],))  # Second segment becomes body
        grid = fill(grid, PURPLE, (snake[-1],))  # Ensure tail is purple

    return np.array(grid)


def maze_solve_program(inputs, key) -> Grid:
    """
    Solve the maze by finding a shortest path from the agent (green) to the
    center (blue), and draw the solution path in red.

    The agent is assumed to sit at (0, 1). The target is the first BLUE cell in
    row-major order; if there is none, the target falls back to the geometric
    center ``(rows // 2, cols // 2)``. Movement is 4-directional through any
    non-BLACK cell.

    Args:
        inputs: A pair ``(grid, _)``; only the grid is used.
        key: Unused; kept for interface compatibility with the other programs.

    Returns:
        Grid: A numpy copy of the grid in which the cells strictly between the
        agent and the target on the path are set to RED (2). If no path exists,
        the copy is returned unchanged.
    """
    grid, _ = inputs

    grid = np.array(grid)  # copy, so the caller's grid is never mutated
    WHITE, BLACK, RED, GREEN, BLUE = 0, 1, 2, 3, 6  # Color indices
    rows = len(grid)
    cols = len(grid[0]) if rows else 0

    def find_agent_and_center():
        agent = (0, 1)  # Agent is always at (0, 1)
        for i in range(rows):
            for j in range(cols):
                if grid[i][j] == BLUE:
                    return agent, (i, j)
        # Fallback center, computed per axis so non-square grids work too.
        return agent, (rows // 2, cols // 2)

    def get_neighbors(pos):
        # Up, down, left, right -- this order decides which of several
        # equal-length paths BFS returns.
        directions = [(-1, 0), (1, 0), (0, -1), (0, 1)]
        return [
            (pos[0] + d[0], pos[1] + d[1])
            for d in directions
            if 0 <= pos[0] + d[0] < rows
            and 0 <= pos[1] + d[1] < cols
            and grid[pos[0] + d[0]][pos[1] + d[1]] != BLACK
        ]

    def bfs(start, goal):
        # Breadth-first search. The parent map replaces per-node path copies,
        # and the deque gives O(1) pops from the left.
        parents = {start: None}
        queue = deque([start])

        while queue:
            node = queue.popleft()
            if node == goal:
                path = []
                while node is not None:
                    path.append(node)
                    node = parents[node]
                return path[::-1]

            for neighbor in get_neighbors(node):
                if neighbor not in parents:
                    parents[neighbor] = node
                    queue.append(neighbor)

        return []  # Return empty path if no solution found

    agent, center = find_agent_and_center()
    path = bfs(agent, center)

    for pos in path[1:-1]:  # Skip agent and center positions
        if grid[pos[0]][pos[1]] != BLUE:  # Don't overwrite the blue center
            grid[pos[0]][pos[1]] = RED

    return grid


class PathNotFoundException(Exception):
    """Raised when no route over passable cells connects the start and end cells."""


def city_network_shortest_path(inputs, key):
    """
    Find a cheapest route from the START cell to the END cell with A* and mark it.

    Movement is 4-directional over non-WHITE cells; WHITE cells are
    impassable. Each step costs 1, except horizontal steps, which cost 0.99 so
    that horizontal movement is slightly preferred.

    Args:
        inputs: A pair ``(grid, _)``; only the grid is used. Cell values:
            0 = WHITE (blocked), 1 = BLACK (passable), 2 = START, 3 = END.
        key: Unused; kept for interface compatibility with the other programs.

    Returns:
        np.ndarray: A copy of the grid in which every cell strictly between
        START and END on the found route is set to PATH (4).

    Raises:
        ValueError: If the grid has no START or no END cell.
        PathNotFoundException: If END cannot be reached from START without
            crossing a WHITE cell.
    """
    grid, _ = inputs

    WHITE, BLACK, START, END = 0, 1, 2, 3
    PATH = 4

    grid = np.array(grid)  # copy, so the caller's grid is never mutated
    rows = len(grid)
    cols = len(grid[0]) if rows else 0

    def find_start_end():
        start, end = None, None
        for i in range(rows):
            for j in range(cols):
                if grid[i, j] == START:
                    start = (i, j)
                elif grid[i, j] == END:
                    end = (i, j)
                if start and end:
                    return start, end
        raise ValueError("Start or end point not found in the grid")

    def get_neighbors(pos):
        # Order: Left, Right, Up, Down. WHITE cells are impassable and are
        # never yielded, so an unreachable END raises PathNotFoundException
        # instead of producing a route through blocked cells.
        for dx, dy in ((0, -1), (0, 1), (-1, 0), (1, 0)):
            nx, ny = pos[0] + dx, pos[1] + dy
            if 0 <= nx < rows and 0 <= ny < cols and grid[nx, ny] != WHITE:
                yield (nx, ny)

    def heuristic(a, b):
        # Manhattan distance.
        # NOTE(review): with 0.99-cost horizontal steps this can slightly
        # exceed the true remaining cost, so it is not strictly admissible.
        # It is kept as-is so the choice among equal-length routes is unchanged.
        return abs(a[0] - b[0]) + abs(a[1] - b[1])

    def step_cost(current, neighbor):
        # Slightly prefer horizontal movements (same row => horizontal step).
        return 0.99 if current[0] == neighbor[0] else 1

    def a_star(start, goal):
        # Heap entries are (priority, insertion counter, cell). The counter
        # breaks priority ties in insertion order.
        frontier = [(0, 0, start)]
        counter = 0
        came_from = {start: None}
        cost_so_far = {start: 0}

        while frontier:
            _, _, current = heapq.heappop(frontier)

            if current == goal:
                return reconstruct_path(came_from, start, goal)

            for neighbor in get_neighbors(current):
                new_cost = cost_so_far[current] + step_cost(current, neighbor)
                if neighbor not in cost_so_far or new_cost < cost_so_far[neighbor]:
                    cost_so_far[neighbor] = new_cost
                    counter += 1
                    priority = new_cost + heuristic(goal, neighbor)
                    heapq.heappush(frontier, (priority, counter, neighbor))
                    came_from[neighbor] = current

        raise PathNotFoundException("No path found from start to end")

    def reconstruct_path(came_from, start, goal):
        # Walk the parent links back from goal, then reverse to start -> goal.
        current = goal
        path = []
        while current != start:
            path.append(current)
            current = came_from[current]
        path.append(start)
        path.reverse()
        return path

    start, end = find_start_end()
    path = a_star(start, end)  # PathNotFoundException propagates to the caller

    for pos in path[1:-1]:
        grid[pos] = PATH

    return grid


def flood_fill_solve(inputs, key) -> Grid:
    """
    Solve the flood fill puzzle by filling connected water cells.

    Starting from the first cell holding 9 (row-major order), every cell
    4-connected to it through cells valued 0, 1 or 9 is recolored to 7. Any
    other value acts as a barrier.

    Args:
        inputs: A pair ``(grid, _)``; only the grid is used. It must be 2-D.
        key: Unused; kept for interface compatibility with the other programs.

    Returns:
        Grid: A numpy copy of the grid with the flood fill completed; the input
        is not modified.

    Raises:
        ValueError: If the grid contains no starting cell (9).
    """
    grid, _ = inputs
    grid = np.array(grid)  # copy, so the caller's grid is never mutated
    height, width = grid.shape

    fill_color = 7  # Dark blue for flooded area
    fillable = (0, 1, 9)  # Only water (and the start marker) is flooded

    start = next(
        ((i, j) for i in range(height) for j in range(width) if grid[i, j] == 9),
        None,
    )
    if start is None:
        raise ValueError("Flood fill grid has no starting cell (9)")

    # Use an explicit stack instead of recursion, so large regions cannot hit
    # Python's recursion limit. Flooded cells become 7, which is not fillable,
    # so each cell is painted at most once.
    stack = [start]
    while stack:
        x, y = stack.pop()
        if not (0 <= x < height and 0 <= y < width):
            continue
        if grid[x, y] not in fillable:  # Barrier, or already flooded
            continue

        grid[x, y] = fill_color
        stack.extend([(x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)])

    return grid


def checkers_capture_solve(inputs, key) -> np.array:
    """
    Perform the single available capture on the Checkers board.

    Black pieces (1) are scanned in row-major order. For each one, the four
    diagonal jumps are tried in order; the first jump over a red piece (2)
    onto an empty square (0) is applied.

    Args:
        inputs: A pair ``(board, _)``; only the board is used. It may be a
            2-D numpy array or a nested list of any rectangular size
            (standard boards are 8x8).
        key: Unused; kept for interface compatibility with the other programs.

    Returns:
        np.array: A new board with the capture performed. If no capture
        exists, the board itself is returned (this shouldn't happen for
        generated puzzles).
    """
    board, _ = inputs
    board = np.asarray(board)  # no copy for ndarray input; converts lists
    n_rows, n_cols = board.shape

    for row in range(n_rows):
        for col in range(n_cols):
            if board[row, col] != 1:  # Only black pieces can capture
                continue
            # Check all possible capture directions (dx: column step, dy: row step)
            for dx, dy in [(-1, -1), (-1, 1), (1, -1), (1, 1)]:
                land_row, land_col = row + 2 * dy, col + 2 * dx
                # If the landing square is in bounds, the jumped square is too.
                if not (0 <= land_row < n_rows and 0 <= land_col < n_cols):
                    continue
                if board[row + dy, col + dx] == 2 and board[land_row, land_col] == 0:
                    new_board = board.copy()
                    new_board[row, col] = 0  # Original position now empty
                    new_board[row + dy, col + dx] = 0  # Remove captured piece
                    new_board[land_row, land_col] = 1  # Move black piece to new position
                    return new_board

    return board  # Return original board if no capture found (shouldn't happen)


def light_bulb_solve(inputs, key) -> np.array:
    """
    Mark every cell lit by a light bulb.

    Each bulb (2) casts light along its row and column in all four directions.
    A ray stops at a wall (1) or at the grid edge, and passes over other
    bulbs. Empty cells (0) that a ray crosses become 3; all other cells keep
    their value.

    Args:
        inputs: A pair ``(grid, _)``; the grid is a 2-D numpy array holding
            walls and light bulbs. The second element is ignored.
        key: Unused.

    Returns:
        np.array: A copy of the grid with illuminated cells marked (3).
    """
    grid, _ = inputs

    n_rows, n_cols = grid.shape
    lit = grid.copy()
    rays = ((1, 0), (-1, 0), (0, 1), (0, -1))  # down, up, right, left

    bulbs = [(r, c) for r in range(n_rows) for c in range(n_cols) if grid[r, c] == 2]
    for bulb_r, bulb_c in bulbs:
        for step_r, step_c in rays:
            r, c = bulb_r, bulb_c
            while 0 <= r < n_rows and 0 <= c < n_cols and lit[r, c] != 1:
                if lit[r, c] == 0:
                    lit[r, c] = 3
                r, c = r + step_r, c + step_c

    return lit


def radio_coverage_solve(inputs, key) -> np.array:
    """
    Color the grid based on shortest path distance to the nearest radio tower,
    accounting for obstacles and maximum range.

    Distances are 4-connected path lengths that never enter an obstacle cell
    (1). They are computed with a single multi-source breadth-first search
    seeded at every tower (2). This gives the same nearest-tower distances as
    one BFS per tower, but in O(cells) time instead of O(towers * cells).

    Args:
        inputs: A pair ``(grid, _)``; the grid is a 2-D numpy array with
            obstacles (1) and radio towers (2). The second element is ignored.
        key: Unused; kept for interface compatibility with the other programs.

    Returns:
        np.array: An array with the same shape and dtype as the grid.
        Obstacles stay 1. Cells within MAX_RANGE steps of the nearest tower
        get ``max(3, 7 - distance)``, so towers themselves become 7. All other
        cells are 0.
    """
    grid, _ = inputs

    height, width = grid.shape
    MAX_RANGE = 6  # Maximum range (using 6 so that distance 7 and greater stays black)

    # distance[x, y] = steps to the nearest tower; inf = unreached or out of range.
    distance = np.full(grid.shape, np.inf)
    queue = deque()
    for x in range(height):
        for y in range(width):
            if grid[x, y] == 2:  # Radio tower
                distance[x, y] = 0
                queue.append((x, y))

    while queue:
        x, y = queue.popleft()
        d = distance[x, y]
        if d >= MAX_RANGE:  # Neighbours would be out of range
            continue
        for dx, dy in ((0, 1), (1, 0), (0, -1), (-1, 0)):
            nx, ny = x + dx, y + dy
            if (
                0 <= nx < height
                and 0 <= ny < width
                and grid[nx, ny] != 1  # Obstacles block the signal
                and distance[nx, ny] == np.inf  # First visit = shortest distance
            ):
                distance[nx, ny] = d + 1
                queue.append((nx, ny))

    # Convert distances to color ranges (3-7), keeping obstacles and out-of-range cells unchanged
    result = np.zeros_like(grid)
    covered = distance <= MAX_RANGE
    result[covered] = np.maximum(3, 7 - distance[covered])
    result[grid == 1] = 1  # Obstacle
    # Cells with distance > MAX_RANGE or unreachable remain 0 (black)

    return result


# Registry of puzzle families. Each entry pairs an input generator with the
# program that solves its inputs. "frequency_weight" is presumably a relative
# sampling weight for the dataset creator -- TODO confirm against the consumer.
TRANSFORMATIONS = [
    {
        "name": task_name,
        "input_generator": generator,
        "program": solver,
        "frequency_weight": weight,
    }
    for task_name, generator, solver, weight in (
        ("snake", create_snake_input, snake_move_program, 1),
        ("maze", create_maze_input, maze_solve_program, 1),
        ("city_network", create_city_network_input, city_network_shortest_path, 1),
        ("flood_fill", create_flood_fill_input, flood_fill_solve, 1),
        ("checkers_capture", create_checkers_input, checkers_capture_solve, 1),
        ("light_bulb_illumination", create_light_bulb_input, light_bulb_solve, 1),
        ("radio_signal_coverage", create_radio_coverage_input, radio_coverage_solve, 1),
    )
]
